#include <sys/mman.h>
#include "dnnacl_ops_kernel_info_store.h"
#include "common/debug/log.h"
#include "common/fmk_error_codes.h"
#include "serialize.h"
#include "fd_manager_ddk.h"
#include "common/util.h"
#include "graph/serialize_factory.h"
#include "rpc/rpc_request_client.h"
#include "dnnacl_cmd.h"
#include "framework/graph/core/node/node_spec.h"

using namespace std;
using namespace hiai;
using namespace hiai::rpc;
using namespace hiai::error;

namespace dnnacl {
// Constructs a kernel-info store and remembers the given name in clName_.
// clName_ is later sent to the RPC peer by GetAllOpsKernelInfo().
DnnaclOpsKernelInfoStore::DnnaclOpsKernelInfoStore(const string& clName)
{
    clName_ = clName;
}

// Nothing to release explicitly: the destructor body is empty, so the
// compiler-generated member teardown is all that runs.
DnnaclOpsKernelInfoStore::~DnnaclOpsKernelInfoStore() = default;

// Initialization hook. This implementation performs no work: `options` is
// ignored and SUCCESS is always returned.
Status DnnaclOpsKernelInfoStore::Initialize(const map<string, string>& options)
{
    (void)options; // intentionally unused
    return SUCCESS;
}

// Finalization hook. This implementation performs no work and always
// returns SUCCESS.
Status DnnaclOpsKernelInfoStore::Finalize()
{
    return SUCCESS;
}

// Hands the `fd` of every handle in `in1`, and then of every handle in `in2`,
// to FdManager::DestroyFd. The vectors themselves are not resized or cleared.
void DnnaclOpsKernelInfoStore::DestroyFd(
    std::vector<hiai::rpc::RpcHandle>& in1, std::vector<hiai::rpc::RpcHandle>& in2) const
{
    // Walk both lists with one loop body; in1 is processed before in2.
    for (auto* handleList : {&in1, &in2}) {
        for (auto& handle : *handleList) {
            FdManager::DestroyFd(handle.fd);
        }
    }
}

// Serializes `computeGraph` and `clName` into two shared buffers and appends
// one RpcHandle per buffer to `commContext`: the graph handle first, then the
// name handle.
//
// Returns FAILURE when no protobuf serializer can be created; the
// DOMI_CHECK_GE_WITH_RETURN checks presumably return FAILURE when a
// descriptor is negative (macro defined elsewhere — confirm). In that case a
// graph handle that was already appended stays in `commContext`.
// Returns SUCCESS once both handles have been appended.
Status DnnaclOpsKernelInfoStore::Serialize(
    const ge::ComputeGraphPtr& computeGraph, string& clName, vector<RpcHandle>& commContext) const
{
    auto serializer =
        ModelSerializeFactory::Instance()->CreateModelSerialize(SerializeType::SERIALIZER_TYPE_PROTOBUFF);
    if (serializer == nullptr) {
        return FAILURE;
    }

    // Step 1: graph payload.
    ge::Buffer serializedGraph = serializer->SerializeGraph(computeGraph);
    int graphBytes = serializedGraph.GetSize();
    int graphFd =
        FdManager::CreateFdAndFlush(ION_ALLOC_NAME, static_cast<size_t>(graphBytes), serializedGraph.GetData());
    DOMI_CHECK_GE_WITH_RETURN(graphFd, 0, FAILURE);

    RpcHandle handle{graphFd, graphBytes};
    commContext.emplace_back(handle);

    // Step 2: name payload. The same handle object is reused, so only fd and
    // size are overwritten before it is appended again.
    int fd = FdManager::CreateFdAndFlush(ION_ALLOC_NAME, static_cast<size_t>(clName.size()), clName.data());
    DOMI_CHECK_GE_WITH_RETURN(fd, 0, FAILURE);

    handle.fd = fd;
    handle.size = clName.size();
    commContext.emplace_back(handle);

    return SUCCESS;
}
// Copies the bytes of `clName` into a shared buffer and appends one RpcHandle
// (descriptor + byte count) describing it to `commContext`.
//
// Nothing is appended unless the descriptor was created; the
// DOMI_CHECK_GE_WITH_RETURN check presumably returns FAILURE for a negative
// descriptor (macro defined elsewhere — confirm). Returns SUCCESS otherwise.
Status DnnaclOpsKernelInfoStore::Serialize(string& clName, vector<RpcHandle>& commContext) const
{
    int fd = FdManager::CreateFdAndFlush(ION_ALLOC_NAME, static_cast<size_t>(clName.size()), clName.data());
    DOMI_CHECK_GE_WITH_RETURN(fd, 0, FAILURE);

    // Value-initialize so that any RpcHandle field other than fd/size has a
    // defined value before the handle is copied into commContext.
    RpcHandle context{};
    context.fd = fd;
    context.size = clName.size();
    commContext.emplace_back(context);

    return SUCCESS;
}

// Op-info overload of UnSerialize. This implementation is a stub: it reads
// nothing from `commContext`, leaves `infos` untouched and returns SUCCESS.
Status DnnaclOpsKernelInfoStore::UnSerialize(const vector<RpcHandle>& commContext, map<string, ge::OpInfo>& infos) const
{
    (void)commContext; // intentionally unused
    (void)infos;       // intentionally unused
    return SUCCESS;
}

// String-list overload of UnSerialize. This implementation is a stub: it
// reads nothing from `commContext`, leaves `result` untouched and returns
// SUCCESS.
Status DnnaclOpsKernelInfoStore::UnSerialize(const vector<RpcHandle>& commContext, vector<string>& result) const
{
    (void)commContext; // intentionally unused
    (void)result;      // intentionally unused
    return SUCCESS;
}

// Requests the op information over RPC and mirrors `infos` into opInfos_.
//
// Flow: serialize clName_ into the request handles, execute
// CMD_DNNACL_OPSKERNELINFOSTORE_GETALLOPSINFO, run UnSerialize on the reply,
// then replace the contents of opInfos_ with `infos`. Any failing step is
// logged and ends the call without touching opInfos_. Once the RPC has been
// issued, the descriptors of both handle lists are passed to DestroyFd on
// every path.
void DnnaclOpsKernelInfoStore::GetAllOpsKernelInfo(map<string, ge::OpInfo>& infos) const
{
    vector<RpcHandle> request;
    vector<RpcHandle> reply;

    // Serialize() takes a non-const reference, so work on a copy of clName_.
    string clName = clName_;
    if (Serialize(clName, request) != SUCCESS) {
        FMK_LOGE("Serialize failed");
        return;
    }

    int32_t execRet =
        RpcRequestClient::GetInstance().Execute(0, 0, CMD_DNNACL_OPSKERNELINFOSTORE_GETALLOPSINFO, request, reply);
    const bool executed = (execRet == SUCCESS);
    // UnSerialize is only attempted when the RPC itself succeeded.
    const bool decoded = executed && (UnSerialize(reply, infos) == SUCCESS);

    if (decoded) {
        opInfos_.clear();
        opInfos_.insert(infos.begin(), infos.end());
    }

    // Single cleanup point shared by the success path and both failure paths.
    DestroyFd(request, reply);

    if (!executed) {
        FMK_LOGE("Execute fail");
    } else if (!decoded) {
        FMK_LOGE("UnSerialize fail");
    }
}

// Returns the names of all nodes in `computeGraph` whose type appears in the
// V310 supported-op table below.
//
// A null graph is logged and yields an empty list. Names are appended in the
// order in which WalkAllNodes invokes the visitor.
vector<string> DnnaclOpsKernelInfoStore::CheckSupported(const ge::ComputeGraphPtr computeGraph) const
{
    std::vector<std::string> supportedOpName;
    if (computeGraph.get() == nullptr) {
        FMK_LOGE("ComputeGraph is null");
        return supportedOpName;
    }

    // Static storage: the table is built once instead of on every call, and
    // the visitor below can refer to it without capturing it.
    static const std::set<std::string> opListV310 = {
        "Data",
        "Convolution",
        "Correlation",
        "Correlation_V2",
        "ConvTranspose",
        "PoolingD",
        "Eltwise",
        "ReLU",
        "ReLU6",
        "Sigmoid",
        "LeakyRelu",
        "AbsVal",
        "TanH",
        "PReLU",
        "BNInference",
        "FusionBatchNorm",
        "Scale",
        "FullyConnection",
        "Softmax",
        "SSDPriorBox",
        "Power",
        "ConcatD",
    };

    // The visitor writes into the local result vector, so that vector must be
    // captured by reference; a capture-less lambda cannot use it.
    // NOTE(review): assumes WalkAllNodes runs the visitor before returning, so
    // the reference stays valid — confirm against the walker implementation.
    // NOTE(review): `node` is a reference, so `node->` only compiles if
    // ge::Node provides operator-> — confirm, otherwise this should be `node.`.
    (void)computeGraph->ROLE(GraphListWalker).WalkAllNodes([&supportedOpName](const ge::Node& node) {
        if (opListV310.find(node->ROLE(NodeSpec).Type()) != opListV310.end()) {
            supportedOpName.push_back(node->ROLE(NodeSpec).Name());
        }
        return SUCCESS;
    });

    FMK_LOGI("supported op size:%zu", supportedOpName.size());
    return supportedOpName;
}
} // namespace dnnacl
